import argparse
import os
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# Make sure autograd is globally enabled: main() backpropagates the KD loss
# through the student to collect weight gradients.  (teacher_probs opts out
# locally via its @torch.no_grad() decorator.)
torch.set_grad_enabled(True)

def parse_blocks(s):
    """Parse a comma-separated list of projection tags.

    Args:
        s: Either ``"all"`` (any case, surrounding whitespace allowed) or a
            comma-separated list drawn from ``q,k,v,o,up,down,gate``. Tags
            are matched case-insensitively and surrounding whitespace is
            ignored.

    Returns:
        The selected tags in the order given, lower-cased and without
        duplicates. For ``"all"`` the full tag list is returned.

    Unknown tags are skipped, as before, but now emit a ``UserWarning`` so a
    typo in ``--blocks`` does not silently drop a projection from the run.
    """
    import warnings

    full = ["q", "k", "v", "o", "up", "down", "gate"]
    if s.strip().lower() == "all":
        return full
    tags = []
    for raw in s.split(","):
        tag = raw.strip().lower()
        if not tag:
            continue  # tolerate empty entries such as "q,,k" or a trailing comma
        if tag not in full:
            warnings.warn(
                f"ignoring unknown block tag {raw.strip()!r}; expected one of {full}"
            )
            continue
        if tag not in tags:
            tags.append(tag)
    return tags

def tag2mod(layer):
    """Map short projection tags to the projection submodules of one layer.

    Looks up ``layer.self_attn.{q,k,v,o}_proj`` (tags ``q``, ``k``, ``v``,
    ``o``) and ``layer.mlp.{up,down,gate}_proj`` (tags ``up``, ``down``,
    ``gate``). Any container or projection attribute that is absent is
    skipped.

    Returns:
        dict mapping tag -> submodule, in the order q, k, v, o, up, down, gate
        (restricted to the tags that were found).
    """
    absent = object()  # sentinel so attributes whose value is None still count
    layout = (
        ("self_attn", (("q", "q_proj"), ("k", "k_proj"), ("v", "v_proj"), ("o", "o_proj"))),
        ("mlp", (("up", "up_proj"), ("down", "down_proj"), ("gate", "gate_proj"))),
    )
    found = {}
    for container_name, projections in layout:
        container = getattr(layer, container_name, absent)
        if container is absent:
            continue
        for tag, attr_name in projections:
            module = getattr(container, attr_name, absent)
            if module is not absent:
                found[tag] = module
    return found

@torch.no_grad()
def teacher_probs(teacher, input_ids, attn_mask, T):
    """Return the teacher's temperature-softened output distribution.

    Runs ``teacher`` without building an autograd graph and applies
    ``softmax(logits / T)`` over the last (vocabulary) dimension.

    Args:
        teacher: Causal LM called as ``teacher(input_ids=..., attention_mask=...,
            use_cache=False)``; its output must expose ``.logits``.
        input_ids: Token ids forwarded to the teacher.
        attn_mask: Attention mask forwarded to the teacher.
        T: Softmax temperature (expected > 0).

    Returns:
        Float32 probabilities with the same shape as ``logits``.
    """
    out = teacher(input_ids=input_ids, attention_mask=attn_mask, use_cache=False)
    # Upcast before the softmax: with bf16 logits the resulting distribution
    # (the KD target) would carry only bf16 precision.
    return torch.softmax(out.logits.float() / T, dim=-1)

def kd_loss(student, input_ids, attn_mask, p_t, T):
    """Temperature-scaled KL(teacher || student) distillation loss.

    Computes the per-position KL divergence between the target distribution
    ``p_t`` and ``softmax(student_logits / T)``, sums it over positions, and
    divides by the batch size (the ``batchmean`` convention). The result is
    scaled by ``T**2`` so gradient magnitudes stay comparable across
    temperatures.

    Args:
        student: Causal LM called like the teacher; output must expose ``.logits``.
        input_ids: Token ids; ``input_ids.shape[0]`` is taken as the batch size.
        attn_mask: Attention mask, also used to exclude masked-out positions
            from the loss. ``None`` counts every position. An all-ones mask
            gives exactly the previous ``reduction="batchmean"`` value.
        p_t: Target probabilities, same shape as the student logits.
        T: Softmax temperature (expected > 0).

    Returns:
        Scalar float32 loss tensor.
    """
    out = student(input_ids=input_ids, attention_mask=attn_mask, use_cache=False)
    # Work in float32: a bf16 log_softmax over a large vocabulary is too coarse
    # for a faithful KL / gradient estimate.
    log_p_s = torch.log_softmax(out.logits.float() / T, dim=-1)
    per_token = torch.nn.functional.kl_div(
        log_p_s, p_t.float(), reduction="none"
    ).sum(dim=-1)
    if attn_mask is not None:
        # Padding positions must not contribute to the distillation signal.
        per_token = per_token * attn_mask.to(per_token.dtype)
    return per_token.sum() / input_ids.shape[0] * (T * T)

def build_tokens(tokenizer, dataset, nsamples, seqlen, device):
    """Build a fixed ``(nsamples, seqlen)`` block of token ids.

    The ``text`` column of ``dataset["train"]`` is joined with blank-line
    separators and tokenized in a single call. The first
    ``nsamples * seqlen`` ids are split row-wise into ``nsamples`` windows of
    ``seqlen`` tokens. If the corpus is too short, the tail is right-padded
    with ``tokenizer.eos_token_id``, or 0 when that id is falsy.
    NOTE(review): padded ids are not masked out by this function.

    Returns:
        LongTensor of shape ``(nsamples, seqlen)`` on ``device``.
    """
    wanted = nsamples * seqlen
    corpus = "\n\n".join(dataset["train"]["text"])
    ids = tokenizer(corpus, return_tensors="pt")["input_ids"][:, :wanted]
    shortfall = wanted - ids.shape[1]
    if shortfall > 0:
        filler = tokenizer.eos_token_id or 0
        ids = torch.nn.functional.pad(ids, (0, shortfall), value=filler)
    windows = ids.view(1, nsamples, seqlen)[0]
    return windows.to(device)

def main():
    """Collect sample-averaged KD gradients for selected student projections.

    For ``--nsamples`` token windows of ``--seqlen`` tokens (from
    ``build_tokens``), the teacher's temperature-softened distribution
    (``teacher_probs``) is the target of the student's KD loss (``kd_loss``).
    The gradient of that loss is taken with respect to every selected
    projection weight (``--blocks``) in each decoder layer of the student.
    These gradients are accumulated on the CPU in float32 and averaged over
    samples. Each is saved to ``<out_grad_dir>/<layer_index>_<tag>.pt``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--teacher_dir", required=True)
    ap.add_argument("--student_dir", required=True)
    ap.add_argument("--out_grad_dir", required=True)
    ap.add_argument("--dataset", type=str, default="wikitext2")
    ap.add_argument("--nsamples", type=int, default=64)
    ap.add_argument("--seqlen", type=int, default=512)
    ap.add_argument("--temperature", type=float, default=1.0)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--blocks", type=str, default="q,k,v,o,up,down,gate")
    ap.add_argument("--batch_size", type=int, default=1)
    args = ap.parse_args()
    # Fail fast on settings that would otherwise produce NaN/inf gradients
    # (division by a zero sample count or temperature) or an empty run.
    if args.nsamples <= 0 or args.seqlen <= 0:
        ap.error("--nsamples and --seqlen must be positive")
    if args.temperature <= 0:
        ap.error("--temperature must be positive")
    blocks = parse_blocks(args.blocks)
    if not blocks:
        ap.error(f"--blocks {args.blocks!r} selects none of q,k,v,o,up,down,gate")
    os.makedirs(args.out_grad_dir, exist_ok=True)

    # NOTE(review): the teacher's tokenizer is used for both models, which
    # assumes teacher and student share a vocabulary.
    tok = AutoTokenizer.from_pretrained(args.teacher_dir, use_fast=True)
    teacher = AutoModelForCausalLM.from_pretrained(
        args.teacher_dir, torch_dtype=torch.bfloat16, device_map=args.device
    )
    student = AutoModelForCausalLM.from_pretrained(
        args.student_dir, torch_dtype=torch.bfloat16, device_map=args.device
    )
    teacher.eval()
    # NOTE(review): train mode is presumably needed for gradient checkpointing
    # to take effect; it also activates any dropout modules in the student.
    student.train()
    student.config.use_cache = False
    if hasattr(student, "gradient_checkpointing_enable"):
        student.gradient_checkpointing_enable()

    # Freeze everything, then re-enable only the selected projection weights.
    # No gradients are then computed or stored for embeddings, norms or lm_head.
    student.requires_grad_(False)
    targets = {}
    for i, layer in enumerate(student.model.layers):
        for tag, mod in tag2mod(layer).items():
            if tag in blocks:
                mod.weight.requires_grad_(True)
                targets[(i, tag)] = mod.weight
    if not targets:
        raise RuntimeError(f"student has no projection modules matching blocks {blocks}")
    # With the embedding frozen, reentrant checkpointed layers could see no
    # input requiring grad and drop their weight gradients. Make the embedding
    # *output* require grad instead of training the embedding weights.
    if hasattr(student, "enable_input_require_grads"):
        student.enable_input_require_grads()
    else:
        student.get_input_embeddings().requires_grad_(True)

    buffers = {
        key: torch.zeros_like(weight, dtype=torch.float32, device="cpu")
        for key, weight in targets.items()
    }

    if args.dataset.lower() in ("wikitext2", "wikitext-2"):
        ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    else:
        ds = load_dataset(args.dataset)
    tokens = build_tokens(tok, ds, args.nsamples, args.seqlen, args.device)
    bs = max(1, int(args.batch_size))
    T = float(args.temperature)

    total = 0
    for s in tqdm(range(0, args.nsamples, bs), desc="collect_kd"):
        input_ids = tokens[s:s + bs]
        n = input_ids.shape[0]  # the final batch may be shorter than bs
        attn_mask = torch.ones_like(input_ids)
        p_t = teacher_probs(teacher, input_ids, attn_mask, T)  # decorated with no_grad
        student.zero_grad(set_to_none=True)
        loss = kd_loss(student, input_ids, attn_mask, p_t, T)
        loss.backward()
        # kd_loss averages over the batch ("batchmean"). Scale each batch's
        # gradient back up by its size so the final division by ``total`` is
        # a true per-sample mean for any --batch_size.
        for key, weight in targets.items():
            if weight.grad is not None:
                buffers[key].add_(weight.grad.detach().to("cpu", torch.float32), alpha=n)
        total += n
        del input_ids, attn_mask, p_t, loss
        torch.cuda.empty_cache()

    for (i, tag), G in buffers.items():
        G /= float(total)
        torch.save(G, os.path.join(args.out_grad_dir, f"{i}_{tag}.pt"))

if __name__ == "__main__":
    # Best effort: allow TF32 kernels for float32 matmuls and cuDNN ops.
    # Any failure while toggling these flags is ignored so the run proceeds.
    try:
        for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
            backend.allow_tf32 = True
    except Exception:
        pass
    main()
